{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example use of the quantization framework\n",
    "\n",
    "In this notebook, we show to use the simple quantization framework developed for this Large Language Model use cas example. \n",
    "\n",
    "This framework is used to handle floating points along their quantized versions. \n",
    "Essentially, it does so by propagating the quantization parameters such as the scales and zero points throughout all the operators found in the function. \n",
    "This function is then meant to be compiled using Concrete Python, creating a circuit that can be executed in FHE."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "import numpy as np\n",
    "from concrete.fhe.tracing import Tracer\n",
    "from quant_framework import DualArray, Quantizer\n",
    "\n",
    "from concrete import fhe"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Function definition \n",
    "\n",
    "Let us define the function to compile using Concrete Python. We also generate the input floating point values. \n",
    "These values will also be used for the calibration step, which essentially computes and stores the quantization parameters while executing the function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "\n",
    "n_values = 10\n",
    "x_calib = np.random.randn(n_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the quantizer. Higher N_BITS will give better results for the floating point comparison\n",
    "N_BITS = 8\n",
    "quantizer = Quantizer(n_bits=N_BITS)\n",
    "\n",
    "\n",
    "def finalize(x):\n",
    "    \"\"\"Finalize the output value.\n",
    "\n",
    "    If the DualArray's integer array is a Tracer, an object used during compilation, return it\n",
    "    as is. Else, return the DualArray. This is called at the end of the run_numpy method because\n",
    "    the compiler can only consider Tracer objects or Numpy arrays as input and outputs.\n",
    "    \"\"\"\n",
    "    if isinstance(x.int_array, Tracer):\n",
    "        return x.int_array\n",
    "    return x\n",
    "\n",
    "\n",
    "def f(q_x):\n",
    "    \"\"\"Function made of quantized operators to compile.\"\"\"\n",
    "\n",
    "    # Convert the inputs to a DualArray instance using the stored calibration data. This is\n",
    "    # necessary as Concrete Python can only compile functions that inputs Numpy arrays, while we\n",
    "    # still want to be able to propagate the quantization parameters throughout the different\n",
    "    # operators\n",
    "    dual_x = DualArray(float_array=x_calib, int_array=q_x, quantizer=quantizer)\n",
    "    dual_y = dual_x.exp(key=\"exp\")\n",
    "    dual_x = dual_x.add(dual_y, \"add\")\n",
    "    dual_x = dual_x.truediv(3, key=\"truediv\")\n",
    "    dual_x = dual_x.quantize(key=\"output\")\n",
    "    return finalize(dual_x)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us define the expected function using Numpy operators. in order to be able to compare the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expected_f(x):\n",
    "    \"\"\"Expected function made of float operators.\"\"\"\n",
    "    return (np.exp(x) + x) / 3"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calibrate the function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the input to a DualArray and quantize it\n",
    "dual_x = DualArray(x_calib, quantizer=quantizer).quantize(key=\"input\")\n",
    "\n",
    "# Calibrate the function in order to compute and store the quantization parameters\n",
    "dual_result = f(dual_x.int_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'add_sub_add_requant': (0.012416286069347535, 0),\n",
      " 'add_sub_add_self': (0.012416286069347535, 0),\n",
      " 'exp': (0.012416286069347535, 0),\n",
      " 'input': (0.012434746578798358, 0),\n",
      " 'output': (0.016833222749391915, 0),\n",
      " 'truediv': (0.012405140942975472, 0)}\n"
     ]
    }
   ],
   "source": [
    "# At this step, the quantizer contains all the needed quantization parameters (scale, zero_point)\n",
    "pprint(quantizer.scale_dict)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clear result comparison\n",
    "\n",
    "We can now compare the expected result with the float result computed bu the quantized function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE between the expected result and the computed float result: 0.0051\n"
     ]
    }
   ],
   "source": [
    "expected_output = expected_f(x_calib)\n",
    "\n",
    "# De-quantize the output\n",
    "float_output = dual_result.dequantize(\"output\").float_array\n",
    "\n",
    "# Compare the float output with the expected output using the MAE score\n",
    "output_mae = np.mean(np.abs(float_output - expected_output))\n",
    "\n",
    "print(f\"MAE between the expected result and the computed float result: {output_mae:.4f}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the function is fully calibrated, it is possible to use it with integer values only as well on a new input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE between the expected result and the computed float result: 0.0028\n"
     ]
    }
   ],
   "source": [
    "# Create a new input and quantize it\n",
    "new_x = np.random.randn(10)\n",
    "new_int_x = quantizer.quantize(new_x, key=\"input\")\n",
    "\n",
    "# Retrieve the integer result\n",
    "new_int_result = f(new_int_x).int_array\n",
    "\n",
    "# Dequantize the output\n",
    "new_float_output = quantizer.dequantize(new_int_result, key=\"output\")\n",
    "\n",
    "# Compare the float output with the expected output using the MAE score\n",
    "new_expected_output = expected_f(new_x)\n",
    "new_output_mae = np.mean(np.abs(new_float_output - new_expected_output))\n",
    "\n",
    "print(f\"MAE between the expected result and the computed float result: {new_output_mae:.4f}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compilation with Concrete Python\n",
    "\n",
    "Since the function can work with integer values only, it is possible to compile it and build an FHE circuit using Concrete Python. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Computation Graph\n",
      "--------------------------------------------------------------------------------\n",
      " %0 = q_x                      # EncryptedScalar<int8>          ∈ [-38, 127]\n",
      " %1 = 0                        # ClearScalar<uint1>             ∈ [0, 0]\n",
      " %2 = subtract(%0, %1)         # EncryptedScalar<int8>          ∈ [-38, 127]\n",
      " %3 = 0                        # ClearScalar<uint1>             ∈ [0, 0]\n",
      " %4 = subtract(%0, %3)         # EncryptedScalar<int8>          ∈ [-38, 127]\n",
      " %5 = subgraph(%4)             # EncryptedScalar<int8>          ∈ [-38, 127]\n",
      " %6 = subgraph(%2)             # EncryptedScalar<uint9>         ∈ [50, 390]\n",
      " %7 = 1                        # ClearScalar<uint1>             ∈ [1, 1]\n",
      " %8 = multiply(%7, %6)         # EncryptedScalar<uint9>         ∈ [50, 390]\n",
      " %9 = add(%5, %8)              # EncryptedScalar<uint10>        ∈ [12, 517]\n",
      "%10 = 0                        # ClearScalar<uint1>             ∈ [0, 0]\n",
      "%11 = subtract(%9, %10)        # EncryptedScalar<uint10>        ∈ [12, 517]\n",
      "%12 = subgraph(%11)            # EncryptedScalar<uint7>         ∈ [3, 127]\n",
      "return %12\n",
      "\n",
      "Subgraphs:\n",
      "\n",
      "    %5 = subgraph(%4):\n",
      "\n",
      "        %0 = input                         # EncryptedScalar<uint1>\n",
      "        %1 = 0.012416286069347535          # ClearScalar<float64>\n",
      "        %2 = multiply(%0, %1)              # EncryptedScalar<float64>\n",
      "        %3 = 0.012416286069347535          # ClearScalar<float64>\n",
      "        %4 = divide(%2, %3)                # EncryptedScalar<float64>\n",
      "        %5 = 0                             # ClearScalar<uint1>\n",
      "        %6 = add(%4, %5)                   # EncryptedScalar<float64>\n",
      "        %7 = rint(%6)                      # EncryptedScalar<float64>\n",
      "        %8 = astype(%7, dtype=int_)        # EncryptedScalar<uint1>\n",
      "        return %8\n",
      "\n",
      "    %6 = subgraph(%2):\n",
      "\n",
      "        %0 = input                         # EncryptedScalar<uint1>\n",
      "        %1 = 0.012416286069347535          # ClearScalar<float64>\n",
      "        %2 = multiply(%0, %1)              # EncryptedScalar<float64>\n",
      "        %3 = exp(%2)                       # EncryptedScalar<float64>\n",
      "        %4 = 0.012416286069347535          # ClearScalar<float64>\n",
      "        %5 = divide(%3, %4)                # EncryptedScalar<float64>\n",
      "        %6 = 0                             # ClearScalar<uint1>\n",
      "        %7 = add(%5, %6)                   # EncryptedScalar<float64>\n",
      "        %8 = rint(%7)                      # EncryptedScalar<float64>\n",
      "        %9 = astype(%8, dtype=int_)        # EncryptedScalar<uint1>\n",
      "        return %9\n",
      "\n",
      "    %12 = subgraph(%11):\n",
      "\n",
      "         %0 = input                         # EncryptedScalar<uint1>\n",
      "         %1 = 0.012405140942975472          # ClearScalar<float64>\n",
      "         %2 = multiply(%0, %1)              # EncryptedScalar<float64>\n",
      "         %3 = 3                             # ClearScalar<uint2>\n",
      "         %4 = divide(%2, %3)                # EncryptedScalar<float64>\n",
      "         %5 = 0.016833222749391915          # ClearScalar<float64>\n",
      "         %6 = divide(%4, %5)                # EncryptedScalar<float64>\n",
      "         %7 = 0                             # ClearScalar<uint1>\n",
      "         %8 = add(%6, %7)                   # EncryptedScalar<float64>\n",
      "         %9 = rint(%8)                      # EncryptedScalar<float64>\n",
      "        %10 = astype(%9, dtype=int_)        # EncryptedScalar<uint1>\n",
      "        return %10\n",
      "--------------------------------------------------------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the compiler\n",
    "compiler = fhe.Compiler(f, {\"q_x\": \"encrypted\"})\n",
    "\n",
    "# Build the inputset as a batch of single quantized input\n",
    "inputset = list(quantizer.quantize(x_calib, key=\"input\"))\n",
    "\n",
    "# Compile the function using the inputset and print the computation graph\n",
    "circuit = compiler.compile(inputset, show_graph=True)\n",
    "\n",
    "# Generate the keys\n",
    "circuit.keygen()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can evaluate the function in FHE using simulation or not, and then compare these results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 3, 3)\n",
      "Results are identical: True\n"
     ]
    }
   ],
   "source": [
    "input_0 = new_int_x[0]\n",
    "\n",
    "# Compute the result in the clear directly using the quantized operators\n",
    "clear_evaluation = f(input_0)\n",
    "\n",
    "# Compute the result in the clear using FHE simulation\n",
    "simulated_evaluation = circuit.simulate(input_0)\n",
    "\n",
    "# Compute the result in FHE\n",
    "fhe_evaluation = circuit.encrypt_run_decrypt(input_0)\n",
    "\n",
    "print((clear_evaluation.int_array, simulated_evaluation, fhe_evaluation))\n",
    "print(\n",
    "    \"Results are identical:\",\n",
    "    all((clear_evaluation.int_array, simulated_evaluation, fhe_evaluation)),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
